1 Preparations

1.1 Construction sites

First order: Harmonise results a bit more.

  • Features used/pre-processing
  • DML multinomial?
  • Fitted propensities close to zero?
  • Do I overfit?

1.2 Notation

  • \(D\) Treatment indicator (binary or multiarm)
  • \(y\) outcome
  • \(X\) features/controls

1.3 Seed

# Fixed seed for reproducible sample splits and forests.
# NOTE(review): the original chunk called rm(list = ls()), which clears only
# the global environment (not loaded packages, options, or the random seed
# state) and can silently break interactive sessions; restart R for a clean
# state instead of clearing the workspace in a script.
seed <- 1909

1.4 Libraries

# loading & modifying data
library("readr")         # to read the data
library("dplyr")         # to manipulate data
library("fastDummies")   # create dummies
# charts & tables
library("ggplot2")       # to create charts
library("patchwork")     # to combine charts
library("flextable")     # design tables
library("modelsummary")  # structure tables
library("kableExtra")    # design table
library("estimatr")
library("ggpubr")
# regression & analysis
library("fixest")        # high dimensional FE
library("skimr")         # skim the data
# machine learning
library("policytree")    # policy tree (Athey & Wager, 2021)
library("grf")           # causal forest
library("rsample")       # data splitting 
library("randomForest")  # Traditional Random Forests
library("mlr3")          # learners
library("mlr3learners")  # learners
library("gbm")           # Generalized Boosted Regression
library("DoubleML")      # Double ML

1.5 Load and prepare data

1.5.1 Load data

# Read a FARS extract: keep pre-2004 crashes, drop imputed columns,
# and remove rows with any missing value.
load_fars <- function(path) {
  read_delim(path, delim = "\t") %>%
    filter(year < 2004) %>%
    select(-starts_with("imp")) %>%
    na.omit()
}

# load full (replication) dataset and small (selected/causal) dataset
df_repl <- load_fars("../data/FARS-data-full-sample.txt")
df_sel  <- load_fars("../data/FARS-data-selection-sample.txt")

# print number of obs
print(paste('Number of observations in the data:',nrow(df_repl),' (full sample);',nrow(df_sel), ' (selected/causal sample)'))
## [1] "Number of observations in the data: 38455  (full sample); 10328  (selected/causal sample)"

1.5.2 Manipulate data

# Build the treatment variables on a data frame:
#  - D:       four-level factor (NONE / Lapbelt / LapShoulderSeat / Childseat),
#             with "NONE" as the reference level
#  - Dbinary: 1 if any restraint (lap-shoulder, lapbelt, or child seat), else 0
# The same pipeline previously appeared twice (once per sample); factoring it
# into one helper keeps the two samples consistent.
prep_treatment <- function(df) {
  df %>%
    mutate(
      D = case_when(lapshould == 1 ~ "LapShoulderSeat",
                    lapbelt   == 1 ~ "Lapbelt",
                    childseat == 1 ~ "Childseat",
                    TRUE           ~ "NONE"),
      D = factor(D, levels = c("NONE", "Lapbelt", "LapShoulderSeat", "Childseat")),
      Dbinary = case_when(lapshould == 1 ~ 1, lapbelt == 1 ~ 1,
                          childseat == 1 ~ 1, TRUE ~ 0)
    )
}

df_repl <- prep_treatment(df_repl)
df_sel  <- prep_treatment(df_sel)

# Convert categorical columns to indicators; drop the redundant D_ dummies
# created by dummy_cols() and the raw crashtm column
df_repl <- dummy_cols(df_repl %>% select(-restraint)) %>% select(-starts_with("D_"), -crashtm)
df_sel  <- dummy_cols(df_sel  %>% select(-restraint)) %>% select(-starts_with("D_"), -crashtm)

# Keep only the analysis variables
keep_vars <- c("splmU55", "thoulbs_I", "numcrash", "weekend", "lowviol",
               "highviol", "ruralrd", "frimp", "suv", "death", "D", "Dbinary",
               "modelyr", "age", "year")
df_repl <- df_repl %>% select(all_of(keep_vars))
df_sel  <- df_sel  %>% select(all_of(keep_vars))

# Training and test data: seeded 50/50 split of each sample
set.seed(seed)
df_repl_split <- initial_split(df_repl, prop = .5)
df_repl_train <- training(df_repl_split)
df_repl_test  <- testing(df_repl_split)
df_sel_split  <- initial_split(df_sel, prop = .5)
df_sel_train  <- training(df_sel_split)
df_sel_test   <- testing(df_sel_split)
# X matrices: all controls; the *_nocontrols versions are an intercept column only
X_repl_train <- as.matrix(df_repl_train %>% select(-death, -D, -Dbinary))
X_repl_test  <- as.matrix(df_repl_test  %>% select(-death, -D, -Dbinary))
X_sel_train  <- as.matrix(df_sel_train  %>% select(-death, -D, -Dbinary))
X_sel_test   <- as.matrix(df_sel_test   %>% select(-death, -D, -Dbinary))
X_repl_train_nocontrols <- as.matrix(rep(1, nrow(X_repl_train)))
X_repl_test_nocontrols  <- as.matrix(rep(1, nrow(X_repl_test)))
X_sel_train_nocontrols  <- as.matrix(rep(1, nrow(X_sel_train)))
X_sel_test_nocontrols   <- as.matrix(rep(1, nrow(X_sel_test)))
# D matrices: multiarm treatment as a factor with "NONE" as reference level
treat_levels <- c("NONE", "Lapbelt", "LapShoulderSeat", "Childseat")
D_repl_train <- factor(df_repl_train$D, levels = treat_levels)
# BUG FIX: the test-set treatment factors were previously built from the
# *training* data frames (df_repl_train$D / df_sel_train$D), giving them the
# wrong length and content for any test-set use.
D_repl_test  <- factor(df_repl_test$D,  levels = treat_levels)
D_sel_train  <- factor(df_sel_train$D,  levels = treat_levels)
D_sel_test   <- factor(df_sel_test$D,   levels = treat_levels)
D_binary_repl_train <- as.matrix(df_repl_train %>% select(Dbinary))
D_binary_repl_test  <- as.matrix(df_repl_test  %>% select(Dbinary))
D_binary_sel_train  <- as.matrix(df_sel_train  %>% select(Dbinary))
D_binary_sel_test   <- as.matrix(df_sel_test   %>% select(Dbinary))
# Y matrices: the death indicator
Y_repl_train <- as.matrix(df_repl_train %>% select(death))
Y_repl_test  <- as.matrix(df_repl_test  %>% select(death))
Y_sel_train  <- as.matrix(df_sel_train  %>% select(death))
Y_sel_test   <- as.matrix(df_sel_test   %>% select(death))

1.6 Summary statistics

# Summary statistics for the selected (causal) sample
tmp <- df_sel %>% select(splmU55, thoulbs_I, modelyr, year, numcrash, weekend,
                         lowviol, highviol, ruralrd, frimp, suv, death)
# remove missing values and standardise each column (only for the inline charts)
tmp_list <- lapply(tmp, na.omit)
tmp_list <- lapply(tmp_list, scale)

# placeholder column generator: datasummary reserves an empty column that
# column_spec() then fills with inline boxplots/histograms
emptycol <- function(x) " "
datasummary(splmU55+thoulbs_I+modelyr+year+numcrash+weekend+lowviol+highviol+ruralrd+frimp+suv+death ~ Mean + SD + Heading("Boxplot") * emptycol + Heading("Histogram") * emptycol, data = tmp) %>%
    column_spec(column = 4, image = spec_boxplot(tmp_list)) %>%
    column_spec(column = 5, image = spec_hist(tmp_list))
Mean SD Boxplot Histogram
splmU55 0.88 0.33
thoulbs_I 2.45 1.54
modelyr 1987.09 8.30
year 1993.42 7.13
numcrash 6.63 4.51
weekend 0.40 0.49
lowviol 0.29 0.45
highviol 0.08 0.27
ruralrd 0.08 0.28
frimp 0.67 0.47
suv 0.10 0.29
death 0.04 0.20

2 Double Machine Learning

2.1 Binary treatment

The next cell initializes the DML models and fits them twice (the first time without controls, the second time with the full set of controls).

# Create DML data objects (the no-controls X is just an intercept column)
dml_data_nocontrols <- double_ml_data_from_matrix(y = Y_repl_train, d = D_binary_repl_train, X_repl_train_nocontrols)
dml_data_controls   <- double_ml_data_from_matrix(y = Y_repl_train, d = D_binary_repl_train, X_repl_train)
# Initiate learners: an xgboost classifier for the propensity nuisance m(X)
# and an xgboost regressor for the outcome nuisance g(X)
lgr::get_logger("mlr3")$set_threshold("warn")
ml_m <- lrn("classif.xgboost", eval_metric = "logloss")
ml_g <- lrn("regr.xgboost", objective = "reg:squarederror")

# Estimate DML without controls: partially linear model with cross-fitting
# (DoubleML defaults; the printed summary below reports 5 folds)
obj_dml = DoubleMLPLR$new(dml_data_nocontrols, ml_g=ml_g, ml_m=ml_m)
obj_dml$fit()
print("------------- No controls ------------- ")
## [1] "------------- No controls ------------- "
print(obj_dml)
## ================= DoubleMLPLR Object ==================
## 
## 
## ------------------ Data summary      ------------------
## Outcome variable: y
## Treatment variable(s): d
## Covariates: X1
## Instrument(s): 
## No. Observations: 19227
## 
## ------------------ Score & algorithm ------------------
## Score function: partialling out
## DML algorithm: dml2
## 
## ------------------ Machine learner   ------------------
## ml_g: regr.xgboost
## ml_m: classif.xgboost
## 
## ------------------ Resampling        ------------------
## No. folds: 5
## No. repeated sample splits: 1
## Apply cross-fitting: TRUE
## 
## ------------------ Fit summary       ------------------
##  Estimates and significance testing of the effect of target variables
##   Estimate. Std. Error t value Pr(>|t|)    
## d -0.100617   0.006553  -15.35   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
# Estimate DML with controls (same learners, now with the full control matrix);
# note that obj_dml is reused/overwritten here
obj_dml = DoubleMLPLR$new(dml_data_controls, ml_g=ml_g, ml_m=ml_m)
obj_dml$fit()
cat("\n\n\n")
print("------------- With controls ------------- ")
## [1] "------------- With controls ------------- "
print(obj_dml)
## ================= DoubleMLPLR Object ==================
## 
## 
## ------------------ Data summary      ------------------
## Outcome variable: y
## Treatment variable(s): d
## Covariates: X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X12
## Instrument(s): 
## No. Observations: 19227
## 
## ------------------ Score & algorithm ------------------
## Score function: partialling out
## DML algorithm: dml2
## 
## ------------------ Machine learner   ------------------
## ml_g: regr.xgboost
## ml_m: classif.xgboost
## 
## ------------------ Resampling        ------------------
## No. folds: 5
## No. repeated sample splits: 1
## Apply cross-fitting: TRUE
## 
## ------------------ Fit summary       ------------------
##  Estimates and significance testing of the effect of target variables
##   Estimate. Std. Error t value Pr(>|t|)    
## d -0.101753   0.006971   -14.6   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

2.2 Multiarm treatment

print("------------- No controls ------------- ")
## [1] "------------- No controls ------------- "
# Intercept-only X, selected sample: ATEs of each arm relative to "NONE"
cfnocontrols <- multi_arm_causal_forest(X=X_sel_train_nocontrols, Y=Y_sel_train, W=D_sel_train)
average_treatment_effect(cfnocontrols)
##                           estimate     std.err               contrast outcome
## Lapbelt - NONE         -0.05339767 0.005772780         Lapbelt - NONE   death
## LapShoulderSeat - NONE -0.04656707 0.005979402 LapShoulderSeat - NONE   death
## Childseat - NONE       -0.04718428 0.005736955       Childseat - NONE   death
# spacer before the with-controls results
cat("\n\n\n")
print("------------- With controls ------------- ")
## [1] "------------- With controls ------------- "
# NOTE(review): this fit uses the *replication* sample (X_repl_*) while the
# no-controls fit above used the *selected* sample -- confirm this switch is
# intended. The NaN ATEs in the output below suggest overlap failure
# (estimated propensities at or near 0/1) in the full sample -- verify.
cfcontrols <- multi_arm_causal_forest(X=X_repl_train, Y=Y_repl_train, W=D_repl_train)
average_treatment_effect(cfcontrols)
##                        estimate std.err               contrast outcome
## Lapbelt - NONE              NaN     NaN         Lapbelt - NONE   death
## LapShoulderSeat - NONE      NaN     NaN LapShoulderSeat - NONE   death
## Childseat - NONE            NaN     NaN       Childseat - NONE   death

3 Causal Forest

3.1 Binary treatment

3.1.1 Non-selected

# Binary-treatment causal forest on the full (replication) training sample,
# with grf parameter tuning over all tunable parameters
cfbinary<- causal_forest(X=X_repl_train, Y=Y_repl_train, W=D_binary_repl_train,tune.parameters = "all")
average_treatment_effect(cfbinary)
##     estimate      std.err 
## -0.113805346  0.008314345

3.1.2 Selected sample (causal)

# Same forest on the selected (causal) training sample.
# NOTE: this overwrites cfbinary, so all later diagnostics and CATE sections
# refer to this selected-sample forest, not the replication-sample one.
cfbinary<- causal_forest(X=X_sel_train, Y=Y_sel_train, W=D_binary_sel_train,tune.parameters = "all")
average_treatment_effect(cfbinary)
##     estimate      std.err 
## -0.061587470  0.008554263

3.1.3 Parameter settings

cfbinary$tuning.output               
## Tuning status: default.
## This indicates tuning was attempted. However, we could not find parameters that were expected to perform better than default: 
## 
## sample.fraction: 0.5
##  mtry: 12
##  min.node.size: 5
##  honesty.fraction: 0.5
##  honesty.prune.leaves: TRUE
##  alpha: 0.05
##  imbalance.penalty: 0

3.1.4 Regression forests

# Regression forests on the selected training sample: Y.regforest fits the
# death outcome on X, D.regforest fits the binary treatment on X
# (both are used below for the out-of-sample R^2 comparison against OLS)
Y.regforest = regression_forest(X_sel_train, Y_sel_train)
D.regforest = regression_forest(X_sel_train, D_binary_sel_train)

3.1.5 Comparison to OLS

Below I estimate the basic OLS regressions for the outcome (Y) and the treatment (D).

# Fit OLS (linear probability models): outcome and treatment each regressed
# on all remaining columns of the selected training sample
olsY<-lm(death~.,data=df_sel_train%>%select(-D,-Dbinary))
olsD<-lm(Dbinary~.,data=df_sel_train%>%select(-death,-D))
# Print header for the outcome-equation results
print("---- OLS for Y ----")
## [1] "---- OLS for Y ----"
summary(olsY)
## 
## Call:
## lm(formula = death ~ ., data = df_sel_train %>% select(-D, -Dbinary))
## 
## Residuals:
##      Min       1Q   Median       3Q      Max 
## -0.13763 -0.05059 -0.03876 -0.02876  1.00188 
## 
## Coefficients:
##               Estimate Std. Error t value Pr(>|t|)    
## (Intercept)  2.2327350  0.9015832   2.476   0.0133 *  
## splmU55     -0.0453104  0.0093381  -4.852 1.26e-06 ***
## thoulbs_I   -0.0008212  0.0018972  -0.433   0.6651    
## numcrash    -0.0007848  0.0006693  -1.173   0.2410    
## weekend      0.0105586  0.0058113   1.817   0.0693 .  
## lowviol      0.0028352  0.0062891   0.451   0.6521    
## highviol     0.0445936  0.0106970   4.169 3.11e-05 ***
## ruralrd     -0.0079654  0.0105910  -0.752   0.4520    
## frimp       -0.0023900  0.0060718  -0.394   0.6939    
## suv         -0.0187080  0.0099814  -1.874   0.0609 .  
## modelyr     -0.0007219  0.0005998  -1.204   0.2288    
## age         -0.0024101  0.0020236  -1.191   0.2337    
## year        -0.0003525  0.0007010  -0.503   0.6151    
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 0.2033 on 5151 degrees of freedom
## Multiple R-squared:  0.01085,    Adjusted R-squared:  0.008544 
## F-statistic: 4.708 on 12 and 5151 DF,  p-value: 1.085e-07
print("---- OLS for D ----")
## [1] "---- OLS for D ----"
summary(olsD)
## 
## Call:
## lm(formula = Dbinary ~ ., data = df_sel_train %>% select(-death, 
##     -D))
## 
## Residuals:
##     Min      1Q  Median      3Q     Max 
## -1.0959 -0.2654  0.0865  0.2839  0.9651 
## 
## Coefficients:
##               Estimate Std. Error t value Pr(>|t|)    
## (Intercept) -72.684428   1.778280 -40.873  < 2e-16 ***
## splmU55       0.063453   0.018418   3.445 0.000575 ***
## thoulbs_I     0.007469   0.003742   1.996 0.045984 *  
## numcrash     -0.001470   0.001320  -1.113 0.265608    
## weekend      -0.005567   0.011462  -0.486 0.627237    
## lowviol       0.043783   0.012405   3.530 0.000420 ***
## highviol     -0.044867   0.021099  -2.127 0.033508 *  
## ruralrd      -0.074257   0.020890  -3.555 0.000382 ***
## frimp        -0.034271   0.011976  -2.862 0.004232 ** 
## suv           0.018379   0.019687   0.934 0.350589    
## modelyr       0.014444   0.001183  12.210  < 2e-16 ***
## age          -0.023478   0.003991  -5.882  4.3e-09 ***
## year          0.022411   0.001383  16.208  < 2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 0.401 on 5151 degrees of freedom
## Multiple R-squared:  0.3005, Adjusted R-squared:  0.2988 
## F-statistic: 184.4 on 12 and 5151 DF,  p-value: < 2.2e-16

Let us compare the out-of-sample performance.

# R-squared
# Out-of-sample R^2: one minus the ratio of the residual sum of squares to
# the total sum of squares (taken around the mean of the actual outcomes).
r2 <- function(preds, actual) {
  resid_ss <- sum((actual - preds)^2)
  total_ss <- sum((actual - mean(actual))^2)
  1 - resid_ss / total_ss
}

# Out-of-sample R^2 of OLS vs regression forests on the selected test set
r2_olsY<-r2(predict(olsY,newdata=df_sel_test),df_sel_test$death)
r2_olsD<-r2(predict(olsD,newdata=df_sel_test),df_sel_test$Dbinary)

r2_rfY<-r2(predict(Y.regforest,newdata=X_sel_test)$predictions,df_sel_test$death)
r2_rfD<-r2(predict(D.regforest,newdata=X_sel_test)$predictions,df_sel_test$Dbinary)

# Tabulate: one row per method, columns for treatment- and outcome-equation R^2
data.frame(Method=c("OLS","RF"),R2_D=c(r2_olsD,r2_rfD),R2_Y=c(r2_olsY,r2_rfY))
##   Method      R2_D        R2_Y
## 1    OLS 0.2922920 0.005642891
## 2     RF 0.3231286 0.012266324

3.1.6 Propensity scores

# Histogram of the estimated propensity scores (W.hat) from cfbinary --
# at this point the selected-sample binary forest fitted above; mass bounded
# away from 0 and 1 indicates adequate overlap
plotdata<-data.frame(what=cfbinary$W.hat)
ggplot(plotdata,aes(x=what))+geom_histogram(bins=100,fill="#f56c42",color="white")+xlim(0,1)+
  theme_minimal()

### Diagnostic tests

test_calibration(cfbinary)
## 
## Best linear fit using forest predictions (on held-out data)
## as well as the mean forest prediction as regressors, along
## with one-sided heteroskedasticity-robust (HC3) SEs:
## 
##                                Estimate Std. Error t value    Pr(>t)    
## mean.forest.prediction          1.00739    0.13639  7.3863 8.758e-14 ***
## differential.forest.prediction  0.62310    0.32952  1.8909   0.02935 *  
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

3.1.7 Influential features

# Variable importance of the binary causal forest, drawn as a sorted
# horizontal bar chart (one bar per column of X_sel_train).
importance <- variable_importance(cfbinary)

var_imp <- data.frame(importance = importance, names = colnames(X_sel_train))
imp_plot <- ggplot(var_imp, aes(x = reorder(names, importance), y = importance)) +
  geom_bar(stat = "identity", fill = "#f56c42", color = "white") +
  theme_minimal() +
  theme(axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1)) +
  labs(x = " ") +
  coord_flip()
imp_plot

### Characterising treatment effect heterogeneity

CATE distribution

# get predictions: in-sample CATEs from the selected-sample binary forest
# (predict() with no newdata returns out-of-bag predictions in grf)
cate<-data.frame(sample="CATEs",tau=predict(cfbinary)$predictions)
# histogram of all CATEs, normalised to relative frequencies
ggplot(cate,aes(x=tau))+
   geom_histogram(aes(y=..count../sum(..count..)),bins=100,alpha=0.95, position = "identity",
                  fill="#f56c42",color="white")+
  theme_minimal()+
  labs(title=" ",x="Conditional Average Treatment Effect",y="Density")

Plot quartiles

# Split sample into 4 groups (quartiles) based on the estimated CATEs
# (the original comment said "5 groups", but ntile(..., n = 4) makes quartiles)
df_sel_train["categroup"] <- factor(ntile(predict(cfbinary)$predictions, n=4))
# calculate the AIPW ATE within each quartile; the value of the `ate`
# assignment is the last expression, so it is returned implicitly
estimated_aipw_ate <- lapply(
  seq(4), function(w) {
  ate <- average_treatment_effect(cfbinary, subset = df_sel_train$categroup == w,method = "AIPW")
})
# Combine into a data frame, one row per quartile
estimated_aipw_ate <- data.frame(do.call(rbind, estimated_aipw_ate))
estimated_aipw_ate$Ntile <- as.numeric(rownames(estimated_aipw_ate))

# create plot: point estimate with a 1.96*SE interval per CATE quartile;
# the dashed line marks a zero effect
ggplot(estimated_aipw_ate) +
  geom_pointrange(aes(x = Ntile, y = estimate, ymax = estimate + 1.96 * `std.err`, ymin = estimate - 1.96 * `std.err`), 
                  size = 1,
                  position = position_dodge(width = .5)) +
  theme_minimal() +
  geom_hline(yintercept=0,linetype="dashed")+
  labs(x = "Quartile", y = "AIPW ATE", title = "AIPW ATEs by  quartiles of the conditional average treatment effect")

# Balance table comparing covariates in the first vs fourth CATE quartile.
# The comparison sample is built first instead of assigning it inside the
# `data =` argument (the original `data = sumstatdata<-...` hid a side-effect
# assignment inside a function call); sumstatdata stays available afterwards.
sumstatdata <- df_sel_train %>% filter(categroup %in% c(1, 4)) %>% select(-D)
datasummary_balance(~categroup,
                    data = sumstatdata,
                    title = "Comparison of the first vs fourth quartile",
                    fmt= '%.3f',
                    dinm_statistic = "p.value")
Comparison of the first vs fourth quartile
1 (N=1291)
2 (N=0)
3 (N=0)
4 (N=1291)
Mean Std. Dev. Mean Std. Dev. Mean Std. Dev. Mean Std. Dev. Diff. in Means p
splmU55 0.841 0.366 0.969 0.173 0.128 0.000
thoulbs_I 2.129 1.071 3.468 1.280 1.339 0.000
numcrash 6.760 4.897 6.239 2.299 -0.521 0.001
weekend 0.388 0.487 0.394 0.489 0.006 0.747
lowviol 0.341 0.474 0.211 0.409 -0.129 0.000
highviol 0.125 0.331 0.016 0.127 -0.108 0.000
ruralrd 0.054 0.227 0.104 0.305 0.050 0.000
frimp 0.559 0.497 0.833 0.373 0.273 0.000
suv 0.030 0.171 0.156 0.363 0.126 0.000
death 0.059 0.235 0.025 0.156 -0.034 0.000
Dbinary 0.700 0.458 0.583 0.493 -0.117 0.000
modelyr 1987.620 6.594 1986.474 9.902 -1.146 0.001
age 3.572 1.282 4.129 1.459 0.556 0.000
year 1993.371 6.125 1992.638 8.486 -0.733 0.012

Now by covariates

# Attach each training observation's estimated CATE, then plot group means
# of tau against selected covariates (df_sel_train_col is reused/overwritten
# for each grouping, so the order of these blocks matters)
df_sel_train["tau"]<-predict(cfbinary)$predictions
# mean CATE by model year, split by the speed-limit indicator
df_sel_train_col<-df_sel_train%>%
  group_by(modelyr,splmU55)%>%
  summarise(tau=mean(tau))
p1<-ggplot(df_sel_train_col,aes(x=modelyr,y=tau,color=as.factor(splmU55)))+geom_point()+
  ylim(-0.125,0)
# mean CATE by crash year, split by the speed-limit indicator
df_sel_train_col<-df_sel_train%>%
  group_by(year,splmU55)%>%
  summarise(tau=mean(tau))
p2<-ggplot(df_sel_train_col,aes(x=year,y=tau,color=as.factor(splmU55)))+geom_point()+
  ylim(-0.125,0)+labs(y="")
# mean CATE by vehicle weight (thoulbs_I is in thousands of pounds, hence *1000)
df_sel_train_col<-df_sel_train%>%
  group_by(thoulbs_I)%>%
  summarise(tau=mean(tau))
p3<-ggplot(df_sel_train_col,aes(x=thoulbs_I*1000,y=tau))+geom_point()+
  ylim(-0.125,0)+labs(y="")
# combine the three panels with a shared legend
ggarrange(p1, p2, p3, ncol=3, nrow=1, common.legend = TRUE, legend="bottom")

CATE distribution by speed limit

# get predictions together with the speed-limit indicator for each observation
cate<-data.frame(sample="CATEs",splmU55=df_sel_train$splmU55,tau=predict(cfbinary)$predictions)
# overlaid histograms of the CATE distribution by speed-limit group
ggplot(cate,aes(x=tau,fill=as.factor(splmU55),group=splmU55))+
   geom_histogram(aes(y=..count../sum(..count..)),bins=100,alpha=0.5, position = "identity",
                 color="white")+
  theme_minimal()+
  labs(title=" ",x="Conditional Average Treatment Effect",y="Density")

3.2 Multiarm treatment

3.2.1 First I reload data

# load full dataset
# NOTE(review): section 3.2.1 repeats the load/prepare code of section 1.5
# verbatim (only the final variable selection differs); consider factoring
# the shared steps into a function so the two stay in sync.

df_repl<-read_delim("../data/FARS-data-full-sample.txt",delim = "\t")%>%
              filter(year<2004)%>%
              select(-starts_with("imp"))
# load small dataset
df_sel<-read_delim("../data/FARS-data-selection-sample.txt",delim = "\t")%>%
              filter(year<2004)%>%
              select(-starts_with("imp"))
# remove rows with missing cases
df_repl<-df_repl[complete.cases(df_repl), ]
df_sel<-df_sel[complete.cases(df_sel), ]



# Treatment indicators: D = four-level factor with "NONE" as the reference
# level, Dbinary = 1 if any restraint is used (same construction as in 1.5.2)
df_repl<-df_repl%>%mutate(D=case_when(lapshould==1~"LapShoulderSeat",lapbelt==1~"Lapbelt",
                                      childseat==1~"Childseat",TRUE~"NONE"),
                          D=factor(D,levels=c("NONE","Lapbelt","LapShoulderSeat","Childseat")),
                          Dbinary=case_when(lapshould==1~1,lapbelt==1~1,childseat==1~1,TRUE~0))
df_sel <-df_sel %>%mutate(D=case_when(lapshould==1~"LapShoulderSeat",lapbelt==1~"Lapbelt",
                                    childseat==1~"Childseat",TRUE~"NONE"),
                         D=factor(D,levels=c("NONE","Lapbelt","LapShoulderSeat","Childseat")),
                         Dbinary=case_when(lapshould==1~1,lapbelt==1~1,childseat==1~1,TRUE~0))
# Convert categorical to indicators; drop the redundant D_ dummies and crashtm
df_repl<-dummy_cols(df_repl%>%select(-restraint))%>%select(-starts_with("D_"),-crashtm)
df_sel<-dummy_cols(df_sel%>%select(-restraint))%>%select(-starts_with("D_"),-crashtm)
#df_repl<-df_repl%>%mutate(day=ifelse(crashtm=="1_day",1,0),night=ifelse(crashtm=="2_night",1,0),morn=ifelse(crashtm=="3_morn",1,0))
#df_sel<- df_sel %>%mutate(day=ifelse(crashtm=="1_day",1,0),night=ifelse(crashtm=="2_night",1,0),morn=ifelse(crashtm=="3_morn",1,0))
# Select variables -- intentionally disabled here: unlike section 1.5, the
# full column set is kept and the X matrices below do the selecting
#df_repl<-df_repl%>%select(splmU55,thoulbs_I,numcrash,weekend,lowviol,highviol,ruralrd,frimp,suv,death,D,Dbinary)
#df_sel<- df_sel %>%select(splmU55,thoulbs_I,numcrash,weekend,lowviol,highviol,ruralrd,frimp,suv,death,D,Dbinary)

# Training and test data: same seeded 50/50 split as in section 1.5
set.seed(seed)
df_repl_split <- initial_split(df_repl, prop = .5)
df_repl_train <- training(df_repl_split)
df_repl_test  <- testing(df_repl_split)
df_sel_split  <- initial_split(df_sel, prop = .5)
df_sel_train  <- training(df_sel_split)
df_sel_test   <- testing(df_sel_split)
# X matrices: only the 9 crash/vehicle controls (no modelyr, age, year here)
ctrl_vars <- c("splmU55", "thoulbs_I", "numcrash", "weekend", "lowviol",
               "highviol", "ruralrd", "frimp", "suv")
X_repl_train <- as.matrix(df_repl_train %>% select(all_of(ctrl_vars)))
X_repl_test  <- as.matrix(df_repl_test  %>% select(all_of(ctrl_vars)))
X_sel_train  <- as.matrix(df_sel_train  %>% select(all_of(ctrl_vars)))
X_sel_test   <- as.matrix(df_sel_test   %>% select(all_of(ctrl_vars)))
X_repl_train_nocontrols <- as.matrix(rep(1, nrow(X_repl_train)))
X_repl_test_nocontrols  <- as.matrix(rep(1, nrow(X_repl_test)))
X_sel_train_nocontrols  <- as.matrix(rep(1, nrow(X_sel_train)))
X_sel_test_nocontrols   <- as.matrix(rep(1, nrow(X_sel_test)))
# D matrices: multiarm treatment as a factor with "NONE" as reference level
treat_levels <- c("NONE", "Lapbelt", "LapShoulderSeat", "Childseat")
D_repl_train <- factor(df_repl_train$D, levels = treat_levels)
# BUG FIX: the test-set treatment factors were previously built from the
# *training* data frames (df_repl_train$D / df_sel_train$D).
D_repl_test  <- factor(df_repl_test$D,  levels = treat_levels)
D_sel_train  <- factor(df_sel_train$D,  levels = treat_levels)
D_sel_test   <- factor(df_sel_test$D,   levels = treat_levels)
D_binary_repl_train <- as.matrix(df_repl_train %>% select(Dbinary))
D_binary_repl_test  <- as.matrix(df_repl_test  %>% select(Dbinary))
D_binary_sel_train  <- as.matrix(df_sel_train  %>% select(Dbinary))
D_binary_sel_test   <- as.matrix(df_sel_test   %>% select(Dbinary))
# Y matrices: the death indicator
Y_repl_train <- as.matrix(df_repl_train %>% select(death))
Y_repl_test  <- as.matrix(df_repl_test  %>% select(death))
Y_sel_train  <- as.matrix(df_sel_train  %>% select(death))
Y_sel_test   <- as.matrix(df_sel_test   %>% select(death))

3.2.2 Then get the forest

# Multiarm causal forest on the selected sample with the 9 crash/vehicle
# controls; ATEs are reported relative to the "NONE" arm
cfmulti <- multi_arm_causal_forest(X=X_sel_train, Y=Y_sel_train, W=D_sel_train)
average_treatment_effect(cfmulti)
##                           estimate     std.err               contrast outcome
## Lapbelt - NONE         -0.05571726 0.007624041         Lapbelt - NONE   death
## LapShoulderSeat - NONE -0.05396493 0.006828650 LapShoulderSeat - NONE   death
## Childseat - NONE       -0.05746134 0.005700159       Childseat - NONE   death

3.2.3 Propensity scores

3.2.4 Influential features

# Get importance scores for the multiarm forest (one per column of X_sel_train,
# which at this point holds the 9 crash/vehicle controls)
importance=variable_importance(cfmulti)

# sorted horizontal bar chart of the importance scores
var_imp <- data.frame(importance=importance,names=colnames(X_sel_train))
ggplot(var_imp,aes(x= reorder(names,importance),y=importance))+
  geom_bar(stat="identity",fill="#f56c42",color="white")+
  theme_minimal()+
  theme(axis.text.x = element_text(angle=45,vjust = 1, hjust=1))+
  labs(x=" ")+
  coord_flip()

### Characterising treatment effect heterogeneity

CATE distribution

Now by covariates

4 Policy Learning

4.1 Single treatment

## policy_tree object 
## Tree depth:  2 
## Actions:  1: control 2: treated 
## Variable splits: 
## (1) split_variable: thoulbs_I  split_value: 3.128 
##   (2) split_variable: thoulbs_I  split_value: 3.101 
##     (4) * action: 1 
##     (5) * action: 2 
##   (3) split_variable: numcrash  split_value: 17 
##     (6) * action: 1 
##     (7) * action: 2

4.2 Multiarm

## policy_tree object 
## Tree depth:  2 
## Actions:  1: NONE 2: Lapbelt 3: LapShoulderSeat 4: Childseat 
## Variable splits: 
## (1) split_variable: thoulbs_I  split_value: 2.364 
##   (2) split_variable: thoulbs_I  split_value: 2.362 
##     (4) * action: 1 
##     (5) * action: 2 
##   (3) split_variable: thoulbs_I  split_value: 4.34 
##     (6) * action: 1 
##     (7) * action: 3